# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Description of the simulation parameters (time, iteration ...)
Usage
-----
.. code::
# Initialize the simulation
s = Simulation(start=0.2, end=1., time_step=0.1)
# do some initialisation stuff with operators,
# print(initial state ...)
# time loop
s.initialize() # --> ready to start
while not s.is_over:
# operators work
op.apply(s)
io.apply(s)
...
# update time step (optional)
adapt_time.apply(s)
# Prepare next step
s.advance()
# end simulation (optional) to prepare io
s.finalize()
io.apply(s)
"""
import sys
import os
import numpy as np
from abc import ABCMeta, abstractmethod
from hysop import dprint, vprint
from hysop.constants import HYSOP_REAL
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.tools.htypes import first_not_None, to_set, check_instance
from hysop.tools.numpywrappers import npw
from hysop.tools.io_utils import IO, IOParams
from hysop.tools.string_utils import vprint_banner
from hysop.core.mpi import default_mpi_params
import numpy as np
class Simulation:
    """Time-Simulation process description (time step, iteration ...)"""

    def __init__(
        self,
        name=None,
        start=0.0,
        end=1.0,
        nb_iter=None,
        dt0=None,
        max_iter=None,
        t=None,
        dt=None,
        times_of_interest=None,
        mpi_params=None,
        quiet=False,
        clamp_t_to_end=True,
        restart=0,
        **kwds,
    ):
        """
        Parameters
        ----------
        name: str, optional
            Name of this simulation. Used as a prefix for the name of the
            time step parameter when that parameter is created here.
        start, end: real, optional
            Initial and final time for simulation.
        nb_iter : int, optional
            Number of iterations required.
        dt0 : real, optional
            Initial time step.
        dt: ScalarParameter, optional
            ScalarParameter that will be used as time step.
            Cannot be a constant parameter.
        t: ScalarParameter, optional
            ScalarParameter that will be used as time.
            Cannot be a constant parameter.
        max_iter : int, optional
            Maximum number of iterations allowed.
            Defaults to 1e9.
        times_of_interest: array-like of float
            List of times ti where the simulation may
            modify current timestep to get t=ti.
            Mainly used by HDF_Writers for precise
            time dependent dumping.
            start <= ti < end is enforced at construction.
            Defaults to empty set.
        mpi_params: optional
            MPI parameters (rank and communicator are read from it).
            Defaults to default_mpi_params().
        quiet: bool, optional
            Forwarded to the ScalarParameter instances created here.
        clamp_t_to_end : bool, optional
            Specify if Simulation adjusts dt for last iteration to have t=end
        restart : int, optional
            Iteration number to start from.
        kwds: dict
            Extra keyword arguments are accepted and ignored.

        Attributes
        ----------
        time : real
            current time (value at the end of the time step)
        current_iteration : int
            start at 0. See notes below.
        is_over : bool
            true if end or max_iter or nb_iter has been reached.
        dt: hysop.parameter.scalar_parameter.ScalarParameter
            The scalar parameter that may be updated.
        t: hysop.parameter.scalar_parameter.ScalarParameter
            The scalar parameter that represents time.
        time_step: double
            Value of the dt parameter.
        target_time_of_interest: float or None
            Current simulation time target (None when there is none left).
        is_time_of_interest: bool
            True when the current time has been flagged as a time of interest.

        Notes
        -----
        * all parameters are optional but either dt0 or nb_iter
          must be set.
        * If both dt0 and nb_iter are given, dt0 is not used.
        * current_iteration = -1 means simulation has not started, and
          self.time = self.start
        * current_iteration = restart (0 by default) after a call to
          initialize, self.time = self.start + self.time_step,
          i.e. targeted time.
        * self.current_iteration = k, runs between self.t and self.tkp1.
          with self.tkp1 == self.time
        """
        self.end = end
        self.start = start
        self.time = start
        self.is_over = False
        self.current_iteration = -1

        if mpi_params is None:
            mpi_params = default_mpi_params()
        self.mpi_params = mpi_params
        self._rank = mpi_params.rank
        self._comm = mpi_params.comm
        self.clamp_t_to_end = clamp_t_to_end
        self._restart = restart

        # A user supplied time step parameter is validated in every
        # configuration, the same way the time parameter is below.
        if dt is not None:
            assert isinstance(dt, ScalarParameter), type(dt)
            assert not dt.const, "dt cannot be a constant parameter."

        if nb_iter is not None:
            self.nb_iter = nb_iter
            if dt0 is not None:
                warning = "------------------------------------------------\n"
                warning += "Simulation warning : both nb_iter and dt0\n"
                warning += " are given, time step input will be ignored.\n"
                warning += "------------------------------------------------\n"
                vprint(warning)
            # nb_iter takes precedence over any dt0 given by the user.
            dt0 = (self.end - self.start) / self.nb_iter
        elif dt0 is None:
            # Neither nb_iter nor dt0: the time step cannot be determined.
            assert dt is None, "dt parameter given, but dt0 has not been given."
            raise ValueError("You must set nb_iter or dt0 value.")

        msg = f'dt0={dt0}, start={start}, end={end}'
        assert (dt0 > 0.0) and -1e-6 < ((end-start)-dt0)/dt0, msg

        # backup initial time step, required to reset simulation.
        self._dt0 = dt0

        dt_name = f"{name}_dt" if (name is not None) else "dt"
        if dt is None:
            dtype = t.dtype if (t is not None) else HYSOP_REAL
            dt = ScalarParameter(
                name=dt_name,
                dtype=dtype,
                min_value=np.finfo(dtype).eps,
                initial_value=dt0,
                quiet=quiet,
            )
        else:
            dt.value = dt0
        self.dt = dt
        self.name = name
        self.max_iter = first_not_None(max_iter, 1e9)

        # Starting time for the current iteration
        if t is None:
            t = ScalarParameter(
                name="t", dtype=dt.dtype, initial_value=start, quiet=quiet
            )
        else:
            assert isinstance(t, ScalarParameter), type(t)
            assert not t.const, "t cannot be a constant parameter."
            t.value = start
        self.t = t

        # tk+1 = t + dt
        self.tkp1 = start + self.time_step
        assert self.end > self.start, "Final time must be greater than initial time"
        assert (
            self.start + self.time_step
        ) <= self.end, "start + step is bigger than end."

        # times of interest: deduplicated, sorted, and checked against
        # the simulation time range.
        times_of_interest = to_set(first_not_None(times_of_interest, []))
        times_of_interest = tuple(sorted(times_of_interest))
        for toi in times_of_interest:
            assert self.start <= toi < self.end, toi
        self.times_of_interest = times_of_interest

        # Time-of-interest bookkeeping. These values are (re)computed by
        # initialize(); defining them here keeps the attributes readable on
        # a simulation that has not been initialized yet.
        self.toi_counter = 0
        self.target_time_of_interest = None
        self.is_time_of_interest = False
        self._last_forced_timestep = None

        # Internal tolerance for timer
        assert t.dtype == dt.dtype
        assert t.dtype in (np.float32, np.float64)
        self.tol = np.finfo(dt.dtype).eps

        # True if initialize has been called.
        self._is_ready = False
        self._next_is_last = False
        self._next_is_time_of_interest = False
        # List of (io_params, params, kwds) registered by write_parameters().
        self._parameters_to_write = []

    def _get_time_step(self):
        """Get current timestep."""
        return self.dt()

    time_step = property(_get_time_step)

    def _write_parameters_to_files(self):
        """Append the current value of every registered parameter to its file.

        Only the io leader rank of each entry writes, and only on iterations
        that are a multiple of the entry's frequency.

        Raises
        ------
        ValueError
            If an entry uses a file format other than IO.ASCII.
        """
        for io_params, params, kwds in self._parameters_to_write:
            if self.mpi_params.rank != io_params.io_leader:
                continue
            if io_params.fileformat is not IO.ASCII:
                msg = f"Unknown format {io_params.fileformat}."
                raise ValueError(msg)
            if (self.current_iteration % io_params.frequency) != 0:
                continue
            # Work on a copy: "file" and "formatter" must stay registered
            # for the next iterations.
            options = kwds.copy()
            stream = options.pop("file")
            formatter = options.pop("formatter")
            raw_values = (p() for p in params)
            values = npw.asarray(
                tuple(v.item() if v.size == 1 else v for v in raw_values)
            )
            values = npw.array2string(
                values,
                max_line_width=npw.inf,
                formatter=formatter,
                legacy="1.13",
                **options,
            )
            # Drop the enclosing brackets and start a new line.
            stream.write("\n" + values[1:-1])
            stream.flush()

    def advance(self, dbg=None, plot_freq=10):
        """Proceed to next time.

        * Advance time and iteration number.
        * Compute the new timestep
        * check if simulation is over.

        Parameters
        ----------
        dbg: optional
            Debug object. When given, it is called with the current time at
            times of interest; otherwise its figure title is refreshed and
            dbg.update() is called every plot_freq iterations.
        plot_freq: int, optional
            Iteration period of the dbg refresh, disabled when <= 0.
        """
        msg = "simu.initialize() must be called before"
        msg += " the time simulation loop."
        assert self._is_ready, msg
        if self.is_over:
            return

        if dbg is not None:
            if self.is_time_of_interest:
                dbg(msg=f"t={self.t()}", nostack=True)
            elif (plot_freq > 0) and ((self.current_iteration % plot_freq) == 0):
                dbg.fig.suptitle(f"it={self.current_iteration+1}, t={self.t()}")
                dbg.update()

        self._write_parameters_to_files()

        if self._next_is_last:
            self.is_over = True
            return

        self._comm.Barrier()
        self.update_time(self.tkp1)
        # Every rank must agree on the current time.
        all_t = self.mpi_params.comm.gather(self.t(), root=0)
        if self.mpi_params.rank == 0:
            assert np.allclose(all_t, all_t[0])

        # self.tkp1 still holds the time that has just been reached.
        self.is_time_of_interest = False
        if self.target_time_of_interest is not None:
            if abs(self.tkp1 - self.target_time_of_interest) <= self.tol:
                self.next_time_of_interest()
                self.is_time_of_interest = True

        self.tkp1 = self.t() + self.time_step

        if abs(self.tkp1 - self.end) <= self.tol:
            self._next_is_last = True
        elif self.tkp1 > self.end:
            self._next_is_last = True
            if self.clamp_t_to_end:
                msg = "** Next iteration is last iteration, clamping dt to achieve t={}. **"
                msg = msg.format(self.end)
                if self.mpi_params.rank == 0:
                    vprint()
                    self._print_banner(msg)
                self.tkp1 = self.end
                self.update_time_step(self.end - self.t())
        elif (self.target_time_of_interest is not None) and (
            self.tkp1 + self.tol >= self.target_time_of_interest
        ):
            msg = "** Next iteration is a time of interest, clamping dt to achieve t={}. **"
            msg = msg.format(self.target_time_of_interest)
            if self.mpi_params.rank == 0:
                vprint()
                self._print_banner(msg)
            self.tkp1 = self.target_time_of_interest
            self.update_time_step(self.target_time_of_interest - self.t())
            # Remember the clamped value so that it can be undone afterwards.
            self._last_forced_timestep = self.dt()
        elif self.dt() == self._last_forced_timestep:
            # The time step is still the one forced to reach a time of
            # interest: restore the initial time step.
            self.update_time_step(self._dt0)
            self._last_forced_timestep = None

        self.current_iteration += 1
        self.time = self.tkp1

        if self.current_iteration + 2 > self.max_iter:
            msg = "** Next iteration will be the last because max_iter={} will be achieved. **"
            msg = msg.format(self.max_iter)
            if self._rank == 0:
                vprint()
                self._print_banner(msg)
            self._next_is_last = True
            self.is_time_of_interest = True

        # Every rank must agree on the time step.
        all_dt = self._comm.gather(self.dt(), root=0)
        if self._rank == 0:
            assert np.allclose(all_dt, all_dt[0])

    def _print_banner(self, msg):
        """Print msg through vprint_banner."""
        vprint_banner(msg)

    def next_time_of_interest(self):
        """Select the next target time of interest.

        Sets target_time_of_interest to times_of_interest[toi_counter] and
        increments toi_counter, or sets it to None when all times of interest
        have been consumed.
        """
        index = self.toi_counter
        if index < len(self.times_of_interest):
            self.target_time_of_interest = self.times_of_interest[index]
            self.toi_counter += 1
        else:
            self.target_time_of_interest = None

    def update_time_step(self, dt):
        """Update time step for the next iteration

        Parameters
        -----------
        dt: double
            the new time step

        Notes
        -----
        Warning : since it does not update tkp1 and time,
        it must be called at the end of the time loop, just
        before 'advance'.
        """
        self.dt.set_value(dt)

    def update_time(self, t):
        """Set the value of the time parameter to t."""
        self.t.set_value(t)

    def initialize(self):
        """(Re)set simulation to initial values
        --> back to iteration 0 (or to the restart iteration) and ready
        to run.

        Also (re)opens the output file of every entry registered with
        write_parameters() on its io leader rank.

        Raises
        ------
        ValueError
            If a registered entry uses a file format other than IO.ASCII.
        """
        tstart, tend = self.start, self.end

        self.toi_counter = 0
        self.next_time_of_interest()
        self.is_time_of_interest = False

        assert (tend - tstart) >= self.tol
        if self.target_time_of_interest is not None:
            assert tend >= self.target_time_of_interest >= tstart
            # The first time of interest is the start time itself.
            if abs(self.target_time_of_interest - tstart) <= self.tol:
                self.next_time_of_interest()
                self.is_time_of_interest = True

        if self.target_time_of_interest is not None:
            assert tend >= self.target_time_of_interest > tstart
            # Never step over the first target time.
            dt0 = min(self._dt0, self.target_time_of_interest - tstart)
        else:
            dt0 = self._dt0

        self.update_time(tstart)
        self.update_time_step(dt0)
        self.tkp1 = tstart + self.time_step
        assert self.tkp1 <= tend
        self._next_is_last = abs(self.tkp1 - self.end) <= self.tol

        self.time = self.tkp1
        self.is_over = False
        self.current_iteration = self._restart
        self._is_ready = True
        self._last_forced_timestep = None

        for io_params, params, kwds in self._parameters_to_write:
            filename = io_params.filename
            fileformat = io_params.fileformat
            # Close the handle left by a previous initialize(), if any.
            previous = kwds.get("file")
            if previous is not None:
                previous.close()
            if self.mpi_params.rank != io_params.io_leader:
                continue
            fresh_start = self._restart == 0
            if fresh_start and os.path.isfile(filename):
                os.remove(filename)
            if fileformat is not IO.ASCII:
                msg = f"Unknown format {fileformat}."
                raise ValueError(msg)
            # The file deliberately stays open: advance() appends to it and
            # finalize() (or the next initialize()) closes it.
            stream = open(filename, "a")
            if fresh_start:
                header = "{}\n".format(
                    "\t".join(f"{p.name}({p.pretty_name})" for p in params)
                )
                stream.write(header)
            kwds["file"] = stream
            kwds.setdefault("formatter", {"float_kind": lambda x: f"{x:.8g}"})

    def finalize(self):
        """Use this function when you need to call an hdf i/o operator
        after the end of the time-loop.

        Marks the simulation as over, resets current_iteration to -1 and
        closes the parameter files opened by initialize(). Calling it when no
        file is open (not initialized yet, or already finalized) is a no-op
        for the file handling part.
        """
        self.is_over = True
        self.current_iteration = -1
        for io_params, params, kwds in self._parameters_to_write:
            if self.mpi_params.rank == io_params.io_leader:
                stream = kwds.pop("file", None)
                if stream is not None:
                    stream.close()

    def print_state(self, verbose=None):
        """Print current simulation parameters

        Parameters
        ----------
        verbose: bool, optional
            If true the message is printed with print(), otherwise it is
            sent to vprint().
        """
        msg = "== Iteration : {0:3d}, from t = {1:6.8} to t = {2:6.8f} =="
        msg = msg.format(self.current_iteration, self.t(), self.time)
        if verbose:
            print(msg)
        else:
            vprint(msg)

    def write_parameters(self, *params, **kwds):
        """Register parameters to be written to a file during the simulation.

        Parameters
        ----------
        params: tuple
            Parameters to write. None entries are skipped, the views of
            every other entry (as given by iterviews()) are registered.
        kwds: dict
            Either io_params, or filename together with the optional
            filepath, fileformat (defaults to IO.ASCII), frequency (defaults
            to 1), io_leader and visu_leader (both default to 0) used to
            build an IOParams. Remaining keywords are stored and forwarded
            to the array formatting call performed by advance().
        """
        if "io_params" in kwds:
            io_params = kwds.pop("io_params")
        else:
            assert "filename" in kwds, "io_params or filename should be specified."
            filename = kwds.pop("filename")
            filepath = kwds.pop("filepath", None)
            fileformat = kwds.pop("fileformat", IO.ASCII)
            frequency = kwds.pop("frequency", 1)
            io_leader = kwds.pop("io_leader", 0)
            visu_leader = kwds.pop("visu_leader", 0)
            io_params = IOParams(
                filename=filename,
                filepath=filepath,
                frequency=frequency,
                fileformat=fileformat,
                io_leader=io_leader,
                visu_leader=visu_leader,
            )

        views = tuple(
            view for p in params if p is not None for _, view in p.iterviews()
        )
        self._parameters_to_write.append((io_params, views, kwds))

    def save_checkpoint(self, datagroup, mpi_params, io_params, compressor):
        """Store the simulation state in the attributes of a zarr group.

        Only the io leader rank writes. The values of t and dt and the
        attributes current_iteration, tkp1 and time are exported.
        """
        import zarr

        check_instance(datagroup, zarr.hierarchy.Group)
        if mpi_params.rank != io_params.io_leader:
            return

        # we need to export simulation parameter values because they
        # may not be part of global problem parameters
        datagroup.attrs["t"] = float(self.t())
        datagroup.attrs["dt"] = float(self.dt())
        for attrname in ("current_iteration", "tkp1", "time"):
            data = getattr(self, attrname)
            try:
                # convert numpy scalars to plain python values
                data = data.item()
            except AttributeError:
                pass
            datagroup.attrs[attrname] = data

    def load_checkpoint(self, datagroup, mpi_params, io_params, relax_constraints):
        """Restore the simulation state from the attributes of a zarr group.

        Times of interest that are prior to the checkpointed time are
        dropped, then t, dt, current_iteration, tkp1 and time are restored.
        """
        import zarr

        check_instance(datagroup, zarr.hierarchy.Group)
        checkpoint_time = datagroup.attrs["time"]
        self.times_of_interest = tuple(
            sorted(toi for toi in self.times_of_interest if toi >= checkpoint_time)
        )
        self.toi_counter = 0
        self.next_time_of_interest()
        self.t._value[...] = datagroup.attrs["t"]  # silent parameter update
        self.dt._value[...] = datagroup.attrs["dt"]  # silent parameter update
        for attrname in ("current_iteration", "tkp1", "time"):
            setattr(self, attrname, datagroup.attrs[attrname])

    def __str__(self):
        return (
            "Simulation parameters : "
            f"from {self.start!s} to {self.end!s}"
            f", time step : {self.time_step!s}"
            f", current time : {self.time!s}"
            f", iteration number : {self.current_iteration!s}"
            f", max number of iterations : {self.max_iter!s}"
        )